{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7d12d77e",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "#  20\\.  决策树模型参数优化及选择  # \n",
    "\n",
    "##  20.1.  介绍  # \n",
    "\n",
    "决策树的实验中，我们从头开始实现了完整的决策树分类过程。当然，实验的最后也介绍了使用 scikit-learn 完成决策树建模。实际上，由于决策树存在剪枝过程，所涉及的的参数非常多，本次挑战将带领你搞定机器学习模型参数优化及选择。 \n",
    "\n",
    "##  20.2.  知识点  # \n",
    "\n",
    "  * CART 决策树分类 \n",
    "\n",
    "  * 网格搜索参数选择 \n",
    "\n",
    "估计你早有这样的疑问，那就当我们构建机器学习模型时，怎样确定合适的参数？难道只能使用默认参数？或者盲目修改？ \n",
    "\n",
    "本次挑战将带你来找寻如何确定合适参数的答案。实际上，有的时候我们是能够估计模型参数或优化方法参数的大致范围，或者经过几次简单的手动修改来观测输出或评价指标变化，从而找到合适的参数。 \n",
    "\n",
    "但是，有的时候通过数次随机尝试来确定参数是不现实的。例如，决策树建模过程中涉及到的一些参数，最大深度，最大叶节点数等多个参数相互影响时，就变成了极其麻烦的排列组合。 \n",
    "\n",
    "下面，我们将介绍两种常用的超参数选择方法：网格搜索和随机参数。 \n",
    "\n",
    "##  20.3.  网格搜索  # \n",
    "\n",
    "网格搜索，简单来讲就是预先制定好各参数的有限个候选取值，然后通过排列组合的方式来传入这些参数，最终通过 K 折交叉验证的方法来确定表现最好的参数。 \n",
    "\n",
    "首先，我们来讲一下什么是 K 折交叉验证。K 折交叉验证是交叉验证中的一种常见方法，其通过将数据集均分成 K 个子集，并依次将其中的 K-1 个子集作为训练集，剩下的 1 个子集用作测试集。在 K 折交叉验证的过程中，每个子集均会被验证一次。 \n",
    "\n",
    "下面通过一张图示来解释 K 折交叉验证的过程： \n",
    "\n",
    "[ ![https://cdn.aibydoing.com/aibydoing/images/document-uid214893labid6671timestamp1531712016231.png](https://cdn.aibydoing.com/aibydoing/images/document-uid214893labid6671timestamp1531712016231.png) ](https://cdn.aibydoing.com/aibydoing/images/document-uid214893labid6671timestamp1531712016231.png)\n",
    "\n",
    "如上图所示，使用 K 折交叉验证的步骤如下： \n",
    "\n",
    "  1. 首先将数据集均分为 K 个子集。 \n",
    "\n",
    "  2. 依次选取其中的 K-1 个子集作为训练集，剩下的 1 个子集用作测试集进行实验。 \n",
    "\n",
    "  3. 计算每次验证结果的平均值作为最终结果。 \n",
    "\n",
    "相比于手动划分数据集，K 折交叉验证让每一条数据都有均等的几率被用于训练和验证，在一定程度上能提升模型的泛化能力。关于交叉验证，后面会有更多介绍，本次挑战只需要通过指定参数即可完成。 \n",
    "\n",
    "回到网格搜索的定义中。例如，模型有参数 A 和参数 B，我们指定参数 A 有  $P1$  ，  $P2$  ，  $P3$  等 3 个参数，参数 B 有  $P4$  ，  $P5$  ，  $P6$  等 3 个参数。那么，通过排列组合有 9 种不同的情况。于是，就可以通过遍历来测试不同参数组合下模型的表现，得到最佳结果。 \n",
    "\n",
    "![image](https://cdn.aibydoing.com/aibydoing/images/document-uid214893labid7506timestamp1547188175196.svg)\n",
    "\n",
    "下面，我们先建立一个决策树分类模型。这里使用 scikit-learn 提供的 digits 数据集，该数据集已经在前面介绍过了。当然后续你可以使用自己的数据集进行练习。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "855fa081",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_digits\n",
    "\n",
    "digits = load_digits()\n",
    "\n",
    "digits.data.shape, digits.target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a7c9d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "((1797, 64), (1797,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f0a7cd2",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "接下来，使用 scikit-learn 提供的决策树算法实现分类。 \n",
    "\n",
    "Exercise 20.1 \n",
    "\n",
    "挑战：建立 CART 决策树完成分类，并得到 5 折交叉验证结果的平均分类准确度。 \n",
    "\n",
    "规定：除设置 ` random_state=42  ` 外，其他使用方法提供的默认参数。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d295a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "## 代码开始 ### (≈3 行代码)\n",
    "model = None\n",
    "## 代码结束 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a734957",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "参考答案  Exercise 20.1 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8446bc61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "### 代码开始 ### (≈3 行代码)\n",
    "model = DecisionTreeClassifier(random_state=42)\n",
    "cvs = cross_val_score(model, digits.data, digits.target, cv=5)\n",
    "np.mean(cvs)\n",
    "### 代码结束 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "741a92f3",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "期望输出 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e9dee05",
   "metadata": {},
   "outputs": [],
   "source": [
    "≈ 0.7903"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf566bb6",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "下面，我们准备使用网格搜索方法对决策树分类中常用的参数进行调参。通过上面的网格搜索示例可知，网格搜索可以通过循环搞定，但是这里我们直接使用 scikit-learn 提供的 ` GridSearchCV  ` 方法来实现。 \n",
    "\n",
    "Exercise 20.2 \n",
    "\n",
    "挑战：学习并使用 ` GridSearchCV  ` 完成网格搜索参数选择，并最终得到 5 等分交叉验证最佳结果。 \n",
    "\n",
    "规定：针对 CART 决策树 ` min_samples_split  ` 搜索候选参数 ` [2,  10,  20]  ` ，及 ` min_samples_leaf  ` 搜索候选参数 ` [1,  5,  10]  ` 。其他参数未特别指明，使用默认即可。 \n",
    "\n",
    "自行阅读学习： [ GridSearchCV 官方文档 ](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV) ｜ [ GridSearchCV 官方示例 ](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_digits.html#sphx-glr-auto-examples-model-selection-plot-grid-search-digits-py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd59f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "## 代码开始 ### (≈3 行代码)\n",
    "gs_model = None\n",
    "## 代码结束 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fe640fb",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "参考答案  Exercise 20.2 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22b21b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "### 代码开始 ### (≈3 行代码)\n",
    "# 需搜索参数字典\n",
    "tuned_parameters = {\"min_samples_split\": [2, 10, 20],\n",
    "                    \"min_samples_leaf\": [1, 5, 10]}\n",
    "\n",
    "# 网格搜索模型\n",
    "gs_model = GridSearchCV(model, tuned_parameters, cv=5)\n",
    "gs_model.fit(digits.data, digits.target)\n",
    "### 代码结束 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5939125f",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "运行测试 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4fa2be7",
   "metadata": {},
   "outputs": [],
   "source": [
    "gs_model.best_score_  # 输出网格搜索交叉验证最佳结果"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e8afd61",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "期望输出 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c59c656",
   "metadata": {},
   "outputs": [],
   "source": [
    "≈ 0.790"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b012021f",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "实际上，你还可以通过模型的一些属性输出相关信息。例如下面查看网格搜索后的最佳参数。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f921dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "gs_model.best_estimator_  # 查看网格搜索最佳参数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4720703c",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "##  20.4.  随机搜索  # \n",
    "\n",
    "网格搜索很直观，也很方便，但是最大的问题在于随着候选参数增多，搜索需要的时间迅速增加。所以，有时候我们也会使用随机搜索的方法。 \n",
    "\n",
    "随机搜索，顾名思义就是经验 + 运气的碰撞。我们依据经验制定一个参数范围，然后在范围内随机选取参数测试，并返回最佳结果。例如下面，我们制定参数 A 在  $[P1, P3]$  区间，参数 B 在  $[P4, P6]$  区间变化。 \n",
    "\n",
    "![image](https://cdn.aibydoing.com/aibydoing/images/document-uid214893labid7506timestamp1547188174821.svg)\n",
    "\n",
    "scikit-learn 同样提供的 ` RandomizedSearchCV  ` 方法来实现随机搜索。 \n",
    "\n",
    "Exercise 20.3 \n",
    "\n",
    "挑战：学习并使用 ` RandomizedSearchCV  ` 完成网格搜索参数选择，并最终得到 5 等分交叉验证最佳结果。 \n",
    "\n",
    "规定：针对 CART 决策树 ` min_samples_split  ` 搜索候选参数区间 ` (2,  20)  ` ，及 ` min_samples_leaf  ` 搜索候选参数区间 ` (1,  10)  ` ，并随机搜索 10 组参数。其他参数未特别指明，使用默认即可。 \n",
    "\n",
    "自行阅读学习： [ RandomizedSearchCV 官方文档 ](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV) ｜ [ RandomizedSearchCV 官方示例 ](https://scikit-learn.org/stable/auto_examples/model_selection/plot_randomized_search.html#sphx-glr-auto-examples-model-selection-plot-randomized-search-py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e939c79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import randint\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "\n",
    "## 代码开始 ### (≈3 行代码)\n",
    "rs_model = None\n",
    "## 代码结束 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe6fe23c",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "参考答案  Exercise 20.3 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e8a426",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "from scipy.stats import randint\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "\n",
    "### 代码开始 ### (≈3 行代码)\n",
    "# 需搜索参数字典\n",
    "tuned_parameters = {\"min_samples_split\": randint(2, 20),\n",
    "                    \"min_samples_leaf\": randint(1, 10)}\n",
    "# 随机搜索模型\n",
    "rs_model = RandomizedSearchCV(model, tuned_parameters, n_iter=10, cv=5)\n",
    "rs_model.fit(digits.data, digits.target)\n",
    "### 代码结束 ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d67965",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "rs_model.best_score_  # 输出网格搜索交叉验证最佳结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca6cf3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "rs_model.best_estimator_  # 查看网格搜索最佳参数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c757585",
   "metadata": {},
   "source": [
    "由于随机搜索每次结果不同，所以此处没有期望输出。 \n",
    "\n",
    "你可能会发现，我们使用网格搜索或随机搜索确定的参数往往还没有一开始的默认参数好。其实，scikit-learn 迭代到今天已经非常成熟，部分默认参数都是许多人在日常使用中总结而来的，所以往往默认参数表现就非常不错。调参在机器学习建模中很重要，但数据和相关预处理方法对于模型的最终表现往往更加重要。 "
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
